{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from math import ceil\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "from utils.input_pipeline import get_image_folders\n",
    "from utils.training import train, optimization_step\n",
    "    \n",
    "torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "torch.backends.cudnn.benchmark = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create data iterators"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "batch_size = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "100000"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# the project helper hands back the train and validation datasets\n",
    "train_folder, val_folder = get_image_folders()\n",
    "\n",
    "# worker/pinning settings are the same for both loaders\n",
    "loader_options = {'num_workers': 4, 'pin_memory': True}\n",
    "\n",
    "# training data is shuffled and uses the notebook-wide batch size\n",
    "train_iterator = DataLoader(\n",
    "    train_folder, batch_size=batch_size, shuffle=True, **loader_options\n",
    ")\n",
    "# validation data keeps its order and is read 256 samples at a time\n",
    "val_iterator = DataLoader(\n",
    "    val_folder, batch_size=256, shuffle=False, **loader_options\n",
    ")\n",
    "\n",
    "# number of training samples\n",
    "train_size = len(train_folder.imgs)\n",
    "train_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10000"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# number of validation samples\n",
    "val_size = len(val_folder.imgs)\n",
    "val_size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from densenet import DenseNet\n",
    "\n",
    "model = DenseNet()\n",
    "# load the model from step3\n",
    "model.load_state_dict(torch.load('model_step3.pytorch_state'))\n",
    "\n",
    "# Put the network on the GPU right away, before any parameter groups or\n",
    "# the optimizer are built from it: the PyTorch docs for `Module.cuda()`\n",
    "# say it should be called before constructing the optimizer. The\n",
    "# `model.cuda()` call further down then has nothing left to move.\n",
    "model = model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# create different parameter groups, selecting parameters by name\n",
    "# in a single pass over the model\n",
    "weights, bn_weights, bn_biases = [], [], []\n",
    "for name, param in model.named_parameters():\n",
    "    # anything named like a conv layer, plus the classifier weight\n",
    "    if 'conv' in name or 'classifier.weight' in name:\n",
    "        weights.append(param)\n",
    "    # weights of layers whose name contains 'norm'\n",
    "    if 'norm' in name and 'weight' in name:\n",
    "        bn_weights.append(param)\n",
    "    # biases of layers whose name contains 'norm'\n",
    "    if 'norm' in name and 'bias' in name:\n",
    "        bn_biases.append(param)\n",
    "\n",
    "# the classifier bias forms a group of its own\n",
    "biases = [model.classifier.bias]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# only the `weights` group gets an explicit weight decay; the other\n",
    "# groups carry no per-group options\n",
    "params = [dict(params=weights, weight_decay=1e-4)]\n",
    "params += [dict(params=group) for group in (biases, bn_weights, bn_biases)]\n",
    "\n",
    "# SGD with Nesterov momentum\n",
    "optimizer = optim.SGD(\n",
    "    params,\n",
    "    lr=1e-4,\n",
    "    momentum=0.9,\n",
    "    nesterov=True\n",
    ")\n",
    "\n",
    "# the criterion and the network are both moved to the GPU\n",
    "loss = nn.CrossEntropyLoss()\n",
    "loss = loss.cuda()\n",
    "model = model.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1563"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# number of epochs handed to the training loop\n",
    "n_epochs = 5\n",
    "\n",
    "# total number of batches in the train set;\n",
    "# rounding up counts a final, smaller batch as well\n",
    "batches_per_epoch = train_size / batch_size\n",
    "n_batches = ceil(batches_per_epoch)\n",
    "n_batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0  1.407 1.100  0.642 0.715  0.862 0.904  2095.310\n",
      "1  1.373 1.084  0.650 0.721  0.866 0.904  2081.717\n",
      "2  1.343 1.063  0.656 0.726  0.870 0.908  2085.719\n",
      "3  1.314 1.060  0.662 0.726  0.874 0.907  2087.037\n",
      "4  1.288 1.042  0.669 0.728  0.879 0.908  2085.568\n",
      "CPU times: user 2h 38min 5s, sys: 15min 52s, total: 2h 53min 58s\n",
      "Wall time: 2h 53min 55s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "def optimization_step_fn(model, loss, x_batch, y_batch):\n",
    "    \"\"\"Run one optimization step on a single batch.\n",
    "\n",
    "    Passes the arguments straight through to `optimization_step`, adding\n",
    "    the notebook-level `optimizer` as the last argument, and hands back\n",
    "    whatever that helper returns.\n",
    "    \"\"\"\n",
    "    step_result = optimization_step(\n",
    "        model, loss, x_batch, y_batch, optimizer\n",
    "    )\n",
    "    return step_result\n",
    "\n",
    "# layout of the line printed for each epoch:\n",
    "# epoch logloss  accuracy    top5_accuracy time  (first value: train, second value: val)\n",
    "all_losses = train(\n",
    "    model, loss, optimization_step_fn,\n",
    "    train_iterator, val_iterator, n_epochs\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# move the network back to the CPU, then write its state dict to disk\n",
    "cpu_state = model.cpu().state_dict()\n",
    "torch.save(cpu_state, 'model_step4.pytorch_state')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
